Oxford Pets with Bounding Boxes¶

In [1]:
import os
import xml.etree.ElementTree as ET
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

class PetDetectionDataset(Dataset):
    def __init__(self, root, transform=None):
        """
        Custom dataset for Oxford-IIIT Pet:
        - Loads images from '<root>/oxford-iiit-pet/images/'
        - Reads bounding boxes from '<root>/oxford-iiit-pet/annotations/xmls/'
        - Derives the class label from the filename's capitalization. In the
          Oxford-IIIT Pet naming convention cat breeds are Capitalized and dog
          breeds are lowercase, so: label 0 = dog, label 1 = cat.
          (The visualization cells below rely on 1 == Cat.)

        Args:
            root (str): directory containing the 'oxford-iiit-pet' folder.
            transform: optional torchvision transform applied to each image.
        """
        self.image_dir = os.path.join(root, "oxford-iiit-pet", "images")
        self.annotation_dir = os.path.join(root, "oxford-iiit-pet", "annotations", "xmls")
        self.transform = transform

        # Only samples with BOTH an image file and an XML annotation that
        # contains at least one bounding box are kept. The three lists below
        # stay index-aligned (same idx -> same sample).
        self.image_files = []
        self.bboxes = []
        self.labels = []

        for xml_file in os.listdir(self.annotation_dir):
            if xml_file.endswith(".xml"):
                image_name = xml_file.replace(".xml", ".jpg")  # Image filename
                image_path = os.path.join(self.image_dir, image_name)
                xml_path = os.path.join(self.annotation_dir, xml_file)

                # Ensure image file exists
                if os.path.exists(image_path):
                    # Parse XML file to get bounding box
                    bbox = self.parse_xml(xml_path)
                    if bbox:
                        self.image_files.append(image_path)
                        self.bboxes.append(bbox)

                        # Breed name = filename minus the trailing "_<index>.jpg" part
                        breed_name = "_".join(image_name.split("_")[:-1])  # Extract breed name
                        label = 0 if breed_name.islower() else 1  # 0 = lowercase breed (dog), 1 = Capitalized breed (cat)
                        self.labels.append(label)

    def parse_xml(self, xml_file):
        """Extract the first object's bounding box from a Pascal-VOC-style XML
        annotation.

        Returns [xmin, ymin, xmax, ymax] in pixel coordinates, or None when
        the file contains no <object> element.
        """
        tree = ET.parse(xml_file)
        root = tree.getroot()
        bbox = None
        for obj in root.findall("object"):
            bndbox = obj.find("bndbox")
            xmin = int(bndbox.find("xmin").text)
            ymin = int(bndbox.find("ymin").text)
            xmax = int(bndbox.find("xmax").text)
            ymax = int(bndbox.find("ymax").text)
            bbox = [xmin, ymin, xmax, ymax]  # Format: (xmin, ymin, xmax, ymax)
            break  # Only take the first object (each image should have one pet)
        return bbox

    def __len__(self):
        # Number of (image, bbox, label) samples collected in __init__.
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return (image, bbox, label) for sample `idx`.

        `bbox` is a float32 tensor (xmin, ymin, xmax, ymax) normalized to
        [0, 1] relative to the ORIGINAL image size, so it remains valid after
        any resize performed by `self.transform`.
        """
        image_path = self.image_files[idx]
        bbox = torch.tensor(self.bboxes[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # Load the image
        image = Image.open(image_path).convert("RGB")
        original_w, original_h = image.size  # Get original image size

        # Apply transformations (resize to 224x224)
        if self.transform:
            image = self.transform(image)

        # Normalize bounding box coordinates relative to original image dimensions
        bbox[0] /= original_w  # Normalize xmin
        bbox[1] /= original_h  # Normalize ymin
        bbox[2] /= original_w  # Normalize xmax
        bbox[3] /= original_h  # Normalize ymax

        return image, bbox, label
In [2]:
# Define image transformations: resize everything to 224x224 and convert to
# a float tensor in [0, 1].
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

# Create dataset
data_root = "./data"
dataset = PetDetectionDataset(root=data_root, transform=transform)

# Split into training and validation sets (80/20).
# A fixed-seed generator makes the split reproducible across kernel restarts,
# so train/val membership is stable between runs (the original unseeded split
# produced a different partition every run).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Create DataLoaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Check dataset size
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
Train size: 2948, Validation size: 738
In [3]:
# Pull a single batch from the training loader and visualize its first sample.
images, bboxes, labels = next(iter(train_loader))
first_img = images[0]
first_box = bboxes[0]  # normalized (xmin, ymin, xmax, ymax)
first_label = labels[0].item()

print(f"Label: {'Cat' if first_label == 1 else 'Dog'}")
print(f"Bounding Box: {first_box.numpy()}")

# CHW tensor -> HWC numpy array for matplotlib
img_np = first_img.permute(1, 2, 0).numpy()

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(img_np)

# Scale the normalized coordinates up to the resized 224x224 display size
xmin, ymin, xmax, ymax = (coord * 224 for coord in first_box)

# Rectangle takes the top-left corner plus width/height
rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                         linewidth=2, edgecolor="r", facecolor="none")
ax.add_patch(rect)
ax.set_title("Sample Image with Computed Bounding Box")
ax.axis("off")
plt.show()
Label: Dog
Bounding Box: [0.204      0.08408409 0.81       0.8918919 ]
No description has been provided for this image
In [4]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualize_grid(dataloader, num_images=9):
    """
    Display a grid of images with their ground-truth bounding boxes and
    class labels.

    The grid shape is derived from `num_images` (default 9 -> 3x3, same
    figure as before), which fixes two defects of the hardcoded 3x3 grid:
    - num_images > 9 used to raise IndexError on `axes[count]`;
    - num_images < 9 left empty frames visible; unused axes are now hidden.

    Args:
        dataloader: yields (images, bboxes, labels) batches; images are CHW
            tensors, bboxes are (xmin, ymin, xmax, ymax) normalized to [0, 1].
        num_images: number of samples to display.
    """
    import math  # local import: grid-size arithmetic only

    n_cols = math.ceil(math.sqrt(num_images))
    n_rows = math.ceil(num_images / n_cols)
    # squeeze=False keeps axes 2-D even for a 1x1 grid, so flatten() is safe
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    count = 0
    for images, bboxes, labels in dataloader:
        for i in range(len(images)):
            if count >= num_images:
                break

            sample_img = images[i]
            sample_bbox = bboxes[i]  # (xmin, ymin, xmax, ymax), normalized
            sample_label = labels[i].item()

            # CHW tensor -> HWC numpy array for matplotlib
            img_np = sample_img.permute(1, 2, 0).numpy()

            # Denormalize bbox coordinates (based on resized 224x224 image)
            xmin = sample_bbox[0] * 224
            ymin = sample_bbox[1] * 224
            width = sample_bbox[2] * 224 - xmin
            height = sample_bbox[3] * 224 - ymin

            # Plot the image
            ax = axes[count]
            ax.imshow(img_np)
            ax.set_xticks([])
            ax.set_yticks([])

            # Draw bounding box
            rect = patches.Rectangle((xmin, ymin), width, height,
                                     linewidth=2, edgecolor="r", facecolor="none")
            ax.add_patch(rect)

            # Label 1 == Cat, 0 == Dog (matches the dataset's filename convention)
            label_text = "Cat" if sample_label == 1 else "Dog"
            ax.text(xmin, ymin - 5, label_text, color="white", fontsize=12,
                    bbox=dict(facecolor="red", alpha=0.5, edgecolor="none"))

            count += 1

        if count >= num_images:
            break

    # Hide any axes that did not receive an image
    for ax in axes[count:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Call the function
visualize_grid(train_loader)
No description has been provided for this image
In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision.models as models

class PetClassifierAndBBox(pl.LightningModule):
    def __init__(self, lambda_bbox=5.0, lr=.001):
        """
        PyTorch Lightning module for joint pet classification and bounding
        box regression.

        - EfficientNet-B0 backbone as feature extractor.
        - Two heads on a shared 512-d embedding:
            - classification (binary: dog=0 / cat=1, matching the dataset)
            - bounding box regression, predicting (xmin, ymin, xmax, ymax)
              normalized to [0, 1] -- the same format the dataset targets use
              (NOT (x, y, width, height) as previously documented).
        - Loss: cross_entropy + lambda_bbox * MSE(bbox)

        Args:
            lambda_bbox: weight of the bbox regression term in the total loss.
            lr: Adam learning rate.
        """
        super().__init__()
        self.save_hyperparameters()

        # Pre-trained EfficientNet-B0; only the convolutional `features`
        # part is kept, its classifier head is discarded.
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # in favor of `weights=...`; kept as-is for compatibility with the
        # environment this notebook was trained in.
        efficientnet = models.efficientnet_b0(pretrained=True)
        self.feature_extractor = efficientnet.features  # Remove classification head

        # Shared projection: global-average-pool the feature map, then map
        # the 1280-d EfficientNet-B0 embedding down to 512.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global Average Pooling
            nn.Flatten(),
            nn.Linear(1280, 512),  # EfficientNet-B0 has 1280 output features
            nn.ReLU()
        )

        # Classification head (2 logits: dog / cat).
        # NOTE(review): no nonlinearity between these two Linear layers, so
        # they compose to a single affine map; left unchanged so existing
        # checkpoints (e.g. final_model.ckpt below) still load.
        self.classification_head = nn.Sequential(
            nn.Linear(512,512),
            nn.Linear(512, 2)  # Output 2 classes
        )

        # Bounding box head (4 outputs: xmin, ymin, xmax, ymax, normalized).
        # Same missing-nonlinearity note as above; kept for checkpoint compat.
        self.bbox_head = nn.Sequential(
            nn.Linear(512,512),
            nn.Linear(512, 4)  # Output 4 coordinates
        )

        # Loss weights
        self.lambda_bbox = lambda_bbox
        self.lr = lr

    def forward(self, x):
        """Return (class_logits, bbox_preds) for a batch of images."""
        features = self.feature_extractor(x)  # Extract features
        features = self.fc(features)  # Shared 512-d embedding

        class_logits = self.classification_head(features)  # Classification head
        bbox_preds = self.bbox_head(features)  # Bounding box head

        return class_logits, bbox_preds

    def _shared_step(self, batch, stage):
        """Compute and log the combined loss for one batch.

        Shared by training_step and validation_step (they were previously
        copy-pasted duplicates). `stage` is the metric prefix ("train" or
        "val"), so the logged keys are identical to the old per-method
        implementations: <stage>_loss, <stage>_class_loss, <stage>_bbox_loss.
        """
        images, bboxes, labels = batch  # Unpack batch

        class_logits, bbox_preds = self(images)  # Forward pass

        loss_class = F.cross_entropy(class_logits, labels)  # Classification loss
        loss_bbox = F.mse_loss(bbox_preds, bboxes)  # Bounding box regression loss
        total_loss = loss_class + self.lambda_bbox * loss_bbox  # Combined loss

        # Logging (per-epoch only, shown in the progress bar)
        self.log(f"{stage}_loss", total_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{stage}_class_loss", loss_class, prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{stage}_bbox_loss", loss_bbox, prog_bar=True, on_step=False, on_epoch=True)

        return total_loss

    def training_step(self, batch, batch_idx):
        """Training step: compute loss and log train_* metrics."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Validation step: compute loss and log val_* metrics."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
In [11]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

# Experiment tracking: CSV metrics under logs/SingleDetector, plus early
# stopping that halts training once val_loss stops improving for 25 epochs.
csv_logger = CSVLogger(save_dir='logs/', name='SingleDetector', version="")
early_stop_callback = EarlyStopping(monitor='val_loss', patience=25, verbose=True, mode="min")

# Instantiate the model with a bbox-loss weight of 5
model = PetClassifierAndBBox(lambda_bbox=5)

# Run training for at most 50 epochs on the previously defined loaders
trainer = pl.Trainer(
    max_epochs=50,
    callbacks=[early_stop_callback],
    logger=csv_logger,
)
trainer.fit(model, train_loader, val_loader)

# Persist the final (last-epoch) weights for the evaluation cells below
trainer.save_checkpoint('logs/SingleDetector/final_model.ckpt')
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type       | Params | Mode 
-----------------------------------------------------------
0 | feature_extractor   | Sequential | 4.0 M  | train
1 | fc                  | Sequential | 655 K  | train
2 | classification_head | Sequential | 263 K  | train
3 | bbox_head           | Sequential | 264 K  | train
-----------------------------------------------------------
5.2 M     Trainable params
0         Non-trainable params
5.2 M     Total params
20.767    Total estimated model params size (MB)
Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Training: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved. New best score: 0.189
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.181
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.168
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.027 >= min_delta = 0.0. New best score: 0.141
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.042 >= min_delta = 0.0. New best score: 0.099
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.090
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.089
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.089
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.079
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.076
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.070
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.067
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.065
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_epochs=50` reached.
In [12]:
# Restore the trained model from its checkpoint and switch it to inference
# mode in one step (`eval()` returns the module itself, so the call chains).
model = PetClassifierAndBBox.load_from_checkpoint('logs/SingleDetector/final_model.ckpt').eval()

def visualize_predictions(dataloader, model, num_images=9):
    """
    Display a grid of images with ground-truth (green) and predicted (red)
    bounding boxes, plus the predicted class label and its probability.

    Fixes over the previous version:
    - batches are moved to the model's device before the forward pass and
      outputs brought back to CPU, so the function works whether the
      checkpoint was loaded onto CPU or GPU (before, CPU batches were fed
      straight into the model);
    - the grid shape is derived from `num_images` (default 9 -> 3x3, same
      figure as before) and unused axes are hidden, avoiding the old
      IndexError for num_images > 9.

    Args:
        dataloader: yields (images, true_bboxes, true_labels); bboxes are
            (xmin, ymin, xmax, ymax) normalized to [0, 1].
        model: module in eval mode returning (class_logits, bbox_preds).
        num_images: number of samples to display.
    """
    import math  # local import: grid-size arithmetic only

    device = next(model.parameters()).device

    n_cols = math.ceil(math.sqrt(num_images))
    n_rows = math.ceil(num_images / n_cols)
    # squeeze=False keeps axes 2-D even for a 1x1 grid, so flatten() is safe
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    count = 0
    for images, true_bboxes, true_labels in dataloader:
        with torch.no_grad():
            class_logits, pred_bboxes = model(images.to(device))
            pred_probs = torch.softmax(class_logits, dim=1).cpu()
            pred_labels = torch.argmax(class_logits, dim=1).cpu()
            pred_bboxes = pred_bboxes.cpu()

        for i in range(len(images)):
            if count >= num_images:
                break

            sample_img = images[i]
            # Denormalize both boxes to the resized 224x224 display size
            true_bbox = true_bboxes[i] * 224
            pred_bbox = pred_bboxes[i] * 224
            pred_label = pred_labels[i].item()
            pred_prob = pred_probs[i][pred_label].item()

            # CHW tensor -> HWC numpy array for matplotlib
            img_np = sample_img.permute(1, 2, 0).numpy()

            # Plot the image
            ax = axes[count]
            ax.imshow(img_np)
            ax.set_xticks([])
            ax.set_yticks([])

            # Ground truth in green
            t_xmin, t_ymin, t_xmax, t_ymax = true_bbox.tolist()
            ax.add_patch(patches.Rectangle((t_xmin, t_ymin), t_xmax - t_xmin, t_ymax - t_ymin,
                                           linewidth=2, edgecolor="g", facecolor="none"))

            # Prediction in red
            p_xmin, p_ymin, p_xmax, p_ymax = pred_bbox.tolist()
            ax.add_patch(patches.Rectangle((p_xmin, p_ymin), p_xmax - p_xmin, p_ymax - p_ymin,
                                           linewidth=2, edgecolor="r", facecolor="none"))

            # Predicted label (1 == Cat, 0 == Dog) and its softmax probability
            label_text = f"{'Cat' if pred_label == 1 else 'Dog'}: {pred_prob:.2f}"
            ax.text(p_xmin, p_ymin - 10, label_text, color="white", fontsize=12,
                    bbox=dict(facecolor="red", alpha=0.5, edgecolor="none"))

            count += 1

        if count >= num_images:
            break

    # Hide any axes that did not receive an image
    for ax in axes[count:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Call the function to visualize predictions
visualize_predictions(val_loader, model)
No description has been provided for this image